{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ed421aca",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found local copy...\n",
      "Loading...\n",
      "Done!\n",
      "To use the Graphein submodule graphein.protein.features.sequence.embeddings, you need to install: biovec \n",
      "biovec cannot be installed via conda\n",
      "WARNING:graphein.protein.visualisation:To use the Graphein submodule graphein.protein.visualisation, you need to install: pytorch3d \n",
      "To do so, use the following command: conda install -c pytorch3d pytorch3d\n",
      "WARNING:graphein.protein.meshes:To use the Graphein submodule graphein.protein.meshes, you need to install: pytorch3d \n",
      "To do so, use the following command: conda install -c pytorch3d pytorch3d\n",
      "INFO:rdkit:Enabling RDKit 2021.09.3 jupyter extensions\n",
      "[08:02:38Using backend: pytorch\n",
      "] /opt/dgl/src/runtime/tensordispatch.cc:43: TensorDispatcher: dlopen failed: libtorch_cuda_cpp.so: cannot open shared object file: No such file or directory\n",
      "Note that this task has obsolete PDBs, thus, we remove it in default. For fair comparison using other models, please retrieve the data.split object for the post removal splits!\n"
     ]
    }
   ],
   "source": [
    "from tdc.single_pred import Develop\n",
    "data = Develop(name = 'SAbDab_Chen')\n",
    "split = data.get_split()\n",
    "\n",
    "gr = data.graphein(graph = 'distance', node_feature = ['amino_acid_one_hot'], \n",
    "                   distance_threshold = 6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "12787677",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "InMemoryProteinGraphDataset(1670)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gr['train']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "805aa8a8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Batch(amino_acid_one_hot=[16], batch=[7458], coords=[16], edge_index=[2, 25827], graph_y=[16], node_id=[16], ptr=[17])\n"
     ]
    }
   ],
   "source": [
    "from torch_geometric.data import DataLoader\n",
    "\n",
    "# Create dataloaders\n",
    "train_loader = DataLoader(gr['train'], batch_size=16, shuffle=False, drop_last=True)\n",
    "for b in train_loader:\n",
    "    print(b)\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0e9e8c3b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Note that this task has obsolete PDBs, thus, we remove it in default. For fair comparison using other models, please retrieve the data.split object for the post removal splits!\n"
     ]
    }
   ],
   "source": [
    "from graphein.ml.conversion import GraphFormatConvertor\n",
    "import graphein.protein as gp\n",
    "\n",
    "graphein_config = gp.ProteinGraphConfig()\n",
    "convertor = GraphFormatConvertor(src_format=\"nx\", dst_format=\"pyg\")\n",
    "\n",
    "gr = data.graphein(config = graphein_config, convertor = convertor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "bab9ca98",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'train':      Antibody_ID                                           Antibody  Y\n",
       " 0           12e8  ['EVQLQQSGAEVVRSGASVKLSCTASGFNIKDYYIHWVKQRPEKG...  0\n",
       " 1           15c8  ['EVQLQQSGAELVKPGASVKLSCTASGFNIKDTYMHWVKQKPEQG...  0\n",
       " 2           1a0q  ['EVQLQESDAELVKPGASVKISCKASGYTFTDHVIHWVKQKPEQG...  1\n",
       " 3           1a14  ['QVQLQQSGAELVKPGASVRMSCKASGYTFTNYNMYWVKQSPGQG...  0\n",
       " 4           1a2y  ['QVQLQESGPGLVAPSQSLSITCTVSGFSLTGYGVNWVRQPPGKG...  0\n",
       " ...          ...                                                ... ..\n",
       " 1681        6rcs  ['QVQLVQSGAEVKKPGASVRVSCKASGYTFTSYGISWVRQAPGQG...  0\n",
       " 1682        6s5a  ['EVKLLESGGGLVQPGGSLKLSCAASGFDFSRYWMNWVRQAPGKG...  0\n",
       " 1683        6u1t  ['EVQLVESGGGLVKPGGSLKLSCAASGFTFSSYDMSWVRQTPEKR...  0\n",
       " 1684        7fab  ['AVQLEQSGPGLVRPSQTLSLTCTVSGTSFDDYYWTWVRQPPGRG...  0\n",
       " 1685        8fab  ['AVKLVQAGGGVVQPGRSLRLSCIASGFTFSNYGMHWVRQAPGKG...  0\n",
       " \n",
       " [1670 rows x 3 columns],\n",
       " 'valid':     Antibody_ID                                           Antibody  Y\n",
       " 0          1cfq  ['QDQLQQSGAELVRPGASVKLSCKALGYIFTDYEIHWVKQTPVHG...  0\n",
       " 1          5dd6  ['QVQLQESGPGLVKPSETLSVTCAVSGVSFSSFWWGWIRQSPGKG...  0\n",
       " 2          4wfe  ['EVQLQQSGPELVKPGASMKTSCKVSGYSFTGYIMNWVKQRHGKN...  0\n",
       " 3          4hii  ['LINLVESGGGVVQPGRSLRLSCAASGFTFSRYGMHWVRQAPGKG...  0\n",
       " 4          1u92  ['RITLKESGPPLVKPTQTLTLTCSFSGFSLSDFGVGVGWIRQPPG...  0\n",
       " ..          ...                                                ... ..\n",
       " 236        1etz  ['QVTLKESGPGILQPSQTLSLTCSFSGFSLSTSGMGVGWIRQPSG...  0\n",
       " 237        5f6h  ['EVQLVESGGGLVQPGGSLRLSCAASGFTFSNSGMIWVRQAPGKG...  0\n",
       " 238        6arp  ['QVQLKQSGPGLVQPSQSLSITCTVSGFDLTDYGVHWVRQSPGKG...  0\n",
       " 239        4a6y  ['TVQLQQPGAELVKPGASVKLSCKASGYTFTSYWMHWVKQRPGRG...  0\n",
       " 240        5b4m  ['DVKLVESGGGLVNLGGSLKLSCAASGFTFSRYYMSWVRQTPEKR...  0\n",
       " \n",
       " [241 rows x 3 columns],\n",
       " 'test':     Antibody_ID                                           Antibody  Y\n",
       " 0          5wtg  ['EVKLVESGGGSVKPGGSLKLSCAASGFSFSTYGMSWVRQTPEKR...  0\n",
       " 1          6ayn  ['QVQLKQSGPGLVQPSQSLSITCTVSGFSLTNYGVHWVRQSPGKG...  0\n",
       " 2          5gkr  ['QVQLQESGPGLVKSSETLSLTCTVSGGSISSYFWSWIRQPPGKG...  1\n",
       " 3          6m87  ['EFQLLESGPELVKPGASVKISCKASGYSFTDYNMNWVKQSNGKS...  0\n",
       " 4          3gm0  ['EVQLQESGPSLVKPSQTLSLTCSVTGDSVTSGYWSWIRQFPGNK...  0\n",
       " ..          ...                                                ... ..\n",
       " 477        5f72  ['EVQLVESGGGLVQPGGSLRLSCAASGFAISASSIHWVRQAPGKC...  0\n",
       " 478        2fjg  ['EVQLVESGGGLVQPGGSLRLSCAASGFTISDYWIHWVRQAPGKG...  0\n",
       " 479        5i8c  ['QEVLVQSGAEVKKPGASVKVSCRAFGYTFTGNALHWVRQAPGQG...  1\n",
       " 480        6ejg  ['QVQLQQSGPELVKPGASVKISCKASGYTFSSSWMNWVKQRPGKG...  0\n",
       " 481        2iff  ['VQLQQSGAELMKPGASVKISCKASGYTFSDYWIEWVKQRPGHGL...  0\n",
       " \n",
       " [478 rows x 3 columns]}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.split"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
